import time
from trainInner import train as innertrain
from trainOuter import train as outertrain
from data import GameData, GameDataMeta
import pickle

# Default option set; main() passes it unchanged to both the inner and the
# outer training routine.  The meaning of each key is defined by those
# routines (trainInner / trainOuter), not in this file.
DEFAULT_OPTIONS = dict(
    name='test',
    save_path='./save_path/',
    manager_units=[50, 50],
    expert1_units=[1, 1],
    expert2_units=[1, 1],
    expert3_units=[1],
    expert4_units=[1, 1],
    expert5_units=[1],
    lr=0.2,
    activ='relu',
    pooling=True,
    batch_size=None,
    ar_layers=1,
    dropout=False,
    l1=0.01,
    l2=0.0,
    pooling_activ='max',
    opt='adam',
    max_itr=10,
    model_seed=3,
    objective='nll',
)

def main(epochs=1, alpha=0.005, beta=0.005):
    """Run the inner/outer (MAML) training loops over the MFAR_3 label tasks.

    In every epoch each label in ``MFAR_3`` is first trained with
    ``innertrain`` and then with ``outertrain``.  The parameters returned by
    one call (``theta_new``) are handed to the next call as ``load_theta``,
    so the parameters are threaded through all tasks and all epochs.

    Args:
        epochs: Number of epochs to run.
        alpha: Value forwarded to ``innertrain`` as ``alpha``.
        beta: Value forwarded to ``outertrain`` as ``beta``.

    Returns:
        A ``(train_losses, test_losses)`` tuple of lists that hold, for each
        epoch, the sum of ``perf[0]`` / ``perf[1]`` over the inner-loop tasks.
    """
    # Task labels used with title=5.  The original script also listed label
    # sets for other titles (e.g. ['0', '1', '2'] for title=1 and title=2,
    # ['0', '1', '3', '4'] for title=3); none of them were used here.
    MFAR_3 = ['0', '1', '2']

    train_losses, test_losses = [], []
    # Shallow copy so the training routines cannot rebind keys of the
    # module-level defaults.
    options = dict(DEFAULT_OPTIONS)

    # The first MAML run needs initial parameters: run trainMFAR.py for 10
    # rounds beforehand, save the resulting model, and load it here.
    # NOTE: pickle must only be used on trusted files; this one is expected
    # to be a locally produced artifact.
    with open('./theta-lab1.pkl', 'rb') as theta_file:
        init_theta = pickle.load(theta_file)    # epoch=1

    for e in range(epochs):
        train_loss, test_loss = 0.0, 0.0
        print('\n------------------------No.{} epochs begins training------------------------------'.format(e+1))
        start_time = time.time()

        # Inner loop
        for label in MFAR_3:
            data = GameData('unique127.csv', 1., title=5, label=label)
            train_data, test_data = data.train_test(0, seed=101)
            perf, _, _, theta_new = innertrain(
                options, [train_data.datalist(), test_data.datalist()],
                alpha=alpha, maml_e=1, load_theta=init_theta)
            train_loss += perf[0]
            test_loss += perf[1]
            init_theta = theta_new

        train_losses.append(train_loss)
        test_losses.append(test_loss)

        # Outer loop
        for i, label in enumerate(MFAR_3):
            data = GameData('unique127.csv', 1., title=5, label=label)
            train_data, test_data = data.train_test(0, seed=101)
            _, _, _, theta_new = outertrain(
                options, [train_data.datalist(), test_data.datalist()],
                load_params=False, beta=beta, title=1, name=i,
                load_theta=init_theta)
            # Carried into the next task and, after the last task, into the
            # next epoch's inner loop.
            init_theta = theta_new

        elapsed_time = time.time() - start_time
        print('\n------------------The end of the {} training, time{}s-----------------------'.format(e+1, '%.2f'% elapsed_time))
        print('\n Total Train:{}, Test:{}'.format(train_loss, test_loss))
        # Checkpoint on epoch index 0, 100, 200, ...  The original tested an
        # undefined name `c` here, which raised NameError on every run.
        if e % 100 == 0:
            with open('./MFAR_3.pkl', 'wb') as checkpoint_file:
                pickle.dump(theta_new, checkpoint_file)

    return train_losses, test_losses

# Start training only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()






